from __future__ import annotations

import argparse
from pathlib import Path

import numpy as np
import pandas as pd


def build_monthly_panel(owid_path: Path, out_path: Path, start: str, end: str) -> pd.DataFrame:
    """Aggregate the daily OWID COVID-19 CSV into a country-month panel.

    Rows without an ``iso_code`` and rows whose code starts with ``OWID_`` are
    discarded, the remaining rows are restricted to ``start <= date <= end``,
    and the data are collapsed to one row per country and calendar month.
    Derived per-million and excess-versus-reported columns are appended, and
    the result is written to ``out_path`` as Parquet.

    Args:
        owid_path: CSV file to read.
        out_path: Parquet file to write; missing parent directories are created.
        start: First date to keep (inclusive), parseable by ``pd.Timestamp``.
        end: Last date to keep (inclusive), parseable by ``pd.Timestamp``.

    Returns:
        The monthly panel, sorted by ``iso_code`` and ``month``.
    """
    wanted_columns = {
        "iso_code",
        "continent",
        "location",
        "date",
        "new_cases",
        "new_deaths",
        "stringency_index",
        "people_vaccinated_per_hundred",
        "people_fully_vaccinated_per_hundred",
        "excess_mortality_cumulative_absolute",
        "excess_mortality_cumulative_per_million",
        "population",
        "gdp_per_capita",
        "population_density",
        "median_age",
        "aged_65_older",
        "hospital_beds_per_thousand",
        "diabetes_prevalence",
        "handwashing_facilities",
        "life_expectancy",
        "human_development_index",
        "extreme_poverty",
    }

    # A callable ``usecols`` selects by name without raising at read time when
    # a listed column is absent; the aggregation below still needs them all.
    df = pd.read_csv(
        owid_path,
        usecols=lambda c: c in wanted_columns,
        parse_dates=["date"],
        low_memory=False,
    )

    # Keep rows that carry an iso_code not prefixed with "OWID_" (presumably
    # OWID's own aggregate / non-ISO entities -- confirm against the codebook)
    # and that fall inside the requested inclusive date window.
    iso = df["iso_code"]
    keep = (
        iso.notna()
        & ~iso.astype(str).str.startswith("OWID_")
        & (df["date"] >= pd.Timestamp(start))
        & (df["date"] <= pd.Timestamp(end))
    )
    df = df[keep].copy()
    df["month"] = df["date"].dt.to_period("M").dt.to_timestamp(how="start")

    # Chronological order within each country is required by the "last"
    # aggregations below (groupby preserves row order inside each group).
    df = df.sort_values(["iso_code", "date"])

    # NOTE(review): "sum" turns a month whose daily values are all missing
    # into 0, so "no data" and "zero reported" are indistinguishable in
    # new_cases / new_deaths. Left unchanged to preserve existing output.
    aggregations = {
        "new_cases": ("new_cases", "sum"),
        "new_deaths": ("new_deaths", "sum"),
        "stringency_index": ("stringency_index", "mean"),
        "people_vaccinated_per_hundred": ("people_vaccinated_per_hundred", "max"),
        "people_fully_vaccinated_per_hundred": ("people_fully_vaccinated_per_hundred", "max"),
        # Cumulative excess mortality can fall within a month, so the
        # month-end level is the last non-null observation, not the in-month
        # maximum; using "max" would distort the monthly differences below.
        "excess_cum_abs": ("excess_mortality_cumulative_absolute", "last"),
        "excess_cum_pm": ("excess_mortality_cumulative_per_million", "last"),
        "population": ("population", "max"),
        "gdp_per_capita": ("gdp_per_capita", "max"),
        "population_density": ("population_density", "max"),
        "median_age": ("median_age", "max"),
        "aged_65_older": ("aged_65_older", "max"),
        "hospital_beds_per_thousand": ("hospital_beds_per_thousand", "max"),
        "diabetes_prevalence": ("diabetes_prevalence", "max"),
        "handwashing_facilities": ("handwashing_facilities", "max"),
        "life_expectancy": ("life_expectancy", "max"),
        "human_development_index": ("human_development_index", "max"),
        "extreme_poverty": ("extreme_poverty", "max"),
    }

    monthly = (
        df.groupby(
            ["iso_code", "continent", "location", "month"],
            as_index=False,
            # Without this, any country whose continent or location is
            # missing would be silently dropped from the panel.
            dropna=False,
        )
        .agg(**aggregations)
        .sort_values(["iso_code", "month"])
    )

    # Month-over-month change of the cumulative series. The first month of
    # each country has no predecessor in the window and is therefore NaN.
    by_country = monthly.groupby("iso_code")
    monthly["excess_deaths"] = by_country["excess_cum_abs"].diff()
    monthly["excess_deaths_pm"] = by_country["excess_cum_pm"].diff()

    monthly["reported_deaths_pm"] = (monthly["new_deaths"] / monthly["population"]) * 1_000_000
    monthly["reported_cases_pm"] = (monthly["new_cases"] / monthly["population"]) * 1_000_000

    monthly["gap_excess_minus_reported"] = monthly["excess_deaths"] - monthly["new_deaths"]
    monthly["gap_pm_excess_minus_reported"] = monthly["excess_deaths_pm"] - monthly["reported_deaths_pm"]

    # Variants with negative monthly excess floored at zero.
    monthly["excess_deaths_clip0"] = monthly["excess_deaths"].clip(lower=0)
    monthly["excess_deaths_pm_clip0"] = monthly["excess_deaths_pm"].clip(lower=0)
    monthly["gap_pm_clip0"] = monthly["excess_deaths_pm_clip0"] - monthly["reported_deaths_pm"]

    # Division by a zero population yields +/-inf; store those as missing.
    monthly.replace([np.inf, -np.inf], np.nan, inplace=True)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    monthly.to_parquet(out_path, index=False)
    return monthly


def main() -> None:
    """Parse command-line options, build the monthly panel, and report its size."""
    cli = argparse.ArgumentParser()
    # (flag, converter, default) for every supported option, in help order.
    option_specs = (
        ("--owid", Path, Path("data/owid/owid-covid-data.csv")),
        ("--out", Path, Path("outputs/data/owid_monthly.parquet")),
        ("--start", str, "2020-01-01"),
        ("--end", str, "2024-12-31"),
    )
    for flag, converter, fallback in option_specs:
        cli.add_argument(flag, type=converter, default=fallback)
    args = cli.parse_args()

    monthly = build_monthly_panel(
        owid_path=args.owid,
        out_path=args.out,
        start=args.start,
        end=args.end,
    )
    print(f"Wrote {args.out} rows={len(monthly):,} cols={len(monthly.columns):,}")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
